{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mix Federated Learning - Logistic Regression\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> The following code is for demonstration only. It is **NOT for production** because of system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to train a `Logistic Regression` model on mix-partitioned data.\n",
    "\n",
    "The following figure shows an example of mix-partitioned data.\n",
    "\n",
    "<img alt=\"dataframe.png\" src=\"../developer/algorithm/resources/mix_data.png\" width=\"600\">\n",
    "\n",
    "SecretFlow supports `Logistic Regression` on mix-partitioned data. The algorithm combines `H`omomorphic `E`ncryption and secure aggregation for better security. For more details, please refer to [Federated Logistic Regression](../developer/algorithm/mix_lr.md).\n"
   ]
  },
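  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough orientation, here is how the two partitioning directions enter a logistic regression model. This is a generic sketch, not a description of SecretFlow's exact protocol, and the symbols $w^{(p)}$, $b$, $n_k$ and $L_k$ are notation assumed here for illustration. Within one vertically partitioned block, each party $p$ holds a slice $x^{(p)}$ of the feature columns, so with a matching weight slice $w^{(p)}$ the linear predictor is a sum of per-party terms:\n",
    "\n",
    "$$z = \\sum_{p} x^{(p)} \\cdot w^{(p)} + b, \\qquad \\hat{y} = \\sigma(z) = \\frac{1}{1 + e^{-z}}$$\n",
    "\n",
    "Across $K$ horizontally partitioned blocks holding $n_k$ samples each, the gradient of the mean loss $L$ is the sample-weighted average of the per-block gradients:\n",
    "\n",
    "$$\\nabla L = \\sum_{k=1}^{K} \\frac{n_k}{n} \\nabla L_k, \\qquad n = \\sum_{k=1}^{K} n_k$$\n",
    "\n",
    "The per-party terms and per-block gradients are the quantities that a protocol of this kind has to combine without exposing them individually."
   ]
  },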
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialization"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize SecretFlow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# Shut down any SecretFlow runtime that is already running.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob', 'carol', 'dave', 'eric'], address='local', num_cpus=64)\n",
    "\n",
    "alice, bob, carol, dave, eric = (\n",
    "    sf.PYU('alice'),\n",
    "    sf.PYU('bob'),\n",
    "    sf.PYU('carol'),\n",
    "    sf.PYU('dave'),\n",
    "    sf.PYU('eric'),\n",
    ")\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preparation"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the [breast cancer](https://archive.ics.uci.edu/ml/datasets/breast+cancer+wisconsin+(diagnostic)) dataset.\n",
    "\n",
    "Let us build mix-partitioned data from this dataset. The partitions are as follows:\n",
    "\n",
    "|label|feature1\\~feature10|feature11\\~feature20|feature21\\~feature30|\n",
    "|---|---|---|---|\n",
    "|alice_y0|alice_x0|bob_x0|carol_x|\n",
    "|alice_y1|alice_x1|bob_x1|dave_x|\n",
    "|alice_y2|alice_x2|bob_x2|eric_x|\n",
    "\n",
    "Alice holds all the labels and features 1~10, Bob holds features 11~20,\n",
    "and Carol, Dave and Eric each hold features 21~30 for a different subset of the rows.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "features, label = load_breast_cancer(return_X_y=True, as_frame=True)\n",
    "features.iloc[:, :] = StandardScaler().fit_transform(features)\n",
    "label = label.to_frame()\n",
    "\n",
    "\n",
    "feat_list = [\n",
    "    features.iloc[:, :10],\n",
    "    features.iloc[:, 10:20],\n",
    "    features.iloc[:, 20:],\n",
    "]\n",
    "\n",
    "alice_y0, alice_y1, alice_y2 = label.iloc[0:200], label.iloc[200:400], label.iloc[400:]\n",
    "alice_x0, alice_x1, alice_x2 = (\n",
    "    feat_list[0].iloc[0:200, :],\n",
    "    feat_list[0].iloc[200:400, :],\n",
    "    feat_list[0].iloc[400:, :],\n",
    ")\n",
    "bob_x0, bob_x1, bob_x2 = (\n",
    "    feat_list[1].iloc[0:200, :],\n",
    "    feat_list[1].iloc[200:400, :],\n",
    "    feat_list[1].iloc[400:, :],\n",
    ")\n",
    "carol_x, dave_x, eric_x = (\n",
    "    feat_list[2].iloc[0:200, :],\n",
    "    feat_list[2].iloc[200:400, :],\n",
    "    feat_list[2].iloc[400:, :],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "\n",
    "\n",
    "def filepath(filename):\n",
    "    return f'{tmp_dir}/{filename}'\n",
    "\n",
    "\n",
    "alice_y0_file, alice_y1_file, alice_y2_file = (\n",
    "    filepath('alice_y0'),\n",
    "    filepath('alice_y1'),\n",
    "    filepath('alice_y2'),\n",
    ")\n",
    "alice_x0_file, alice_x1_file, alice_x2_file = (\n",
    "    filepath('alice_x0'),\n",
    "    filepath('alice_x1'),\n",
    "    filepath('alice_x2'),\n",
    ")\n",
    "bob_x0_file, bob_x1_file, bob_x2_file = (\n",
    "    filepath('bob_x0'),\n",
    "    filepath('bob_x1'),\n",
    "    filepath('bob_x2'),\n",
    ")\n",
    "carol_x_file, dave_x_file, eric_x_file = (\n",
    "    filepath('carol_x'),\n",
    "    filepath('dave_x'),\n",
    "    filepath('eric_x'),\n",
    ")\n",
    "\n",
    "alice_x0.to_csv(alice_x0_file, index=False)\n",
    "alice_x1.to_csv(alice_x1_file, index=False)\n",
    "alice_x2.to_csv(alice_x2_file, index=False)\n",
    "bob_x0.to_csv(bob_x0_file, index=False)\n",
    "bob_x1.to_csv(bob_x1_file, index=False)\n",
    "bob_x2.to_csv(bob_x2_file, index=False)\n",
    "carol_x.to_csv(carol_x_file, index=False)\n",
    "dave_x.to_csv(dave_x_file, index=False)\n",
    "eric_x.to_csv(eric_x_file, index=False)\n",
    "\n",
    "alice_y0.to_csv(alice_y0_file, index=False)\n",
    "alice_y1.to_csv(alice_y1_file, index=False)\n",
    "alice_y2.to_csv(alice_y2_file, index=False)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now create the `MixDataFrame`s `x` and `y` for later use.\n",
    "A `MixDataFrame` is a list of `HDataFrame`s or `VDataFrame`s.\n",
    "\n",
    "> See SecretFlow's [DataFrame](../user_guide/preprocessing/DataFrame.ipynb) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "vdf_x0 = sf.data.vertical.read_csv(\n",
    "    {alice: alice_x0_file, bob: bob_x0_file, carol: carol_x_file}\n",
    ")\n",
    "vdf_x1 = sf.data.vertical.read_csv(\n",
    "    {alice: alice_x1_file, bob: bob_x1_file, dave: dave_x_file}\n",
    ")\n",
    "vdf_x2 = sf.data.vertical.read_csv(\n",
    "    {alice: alice_x2_file, bob: bob_x2_file, eric: eric_x_file}\n",
    ")\n",
    "vdf_y0 = sf.data.vertical.read_csv({alice: alice_y0_file})\n",
    "vdf_y1 = sf.data.vertical.read_csv({alice: alice_y1_file})\n",
    "vdf_y2 = sf.data.vertical.read_csv({alice: alice_y2_file})\n",
    "\n",
    "\n",
    "x = sf.data.mix.MixDataFrame(partitions=[vdf_x0, vdf_x1, vdf_x2])\n",
    "y = sf.data.mix.MixDataFrame(partitions=[vdf_y0, vdf_y1, vdf_y2])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Construct the `HEU` and `SecureAggregator` instances used in training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "from secretflow.security.aggregation import SecureAggregator\n",
    "import spu\n",
    "\n",
    "\n",
    "def heu_config(sk_keeper: str, evaluators: List[str]):\n",
    "    return {\n",
    "        'sk_keeper': {'party': sk_keeper},\n",
    "        'evaluators': [{'party': evaluator} for evaluator in evaluators],\n",
    "        'mode': 'PHEU',\n",
    "        'he_parameters': {\n",
    "            'schema': 'paillier',\n",
    "            'key_pair': {'generate': {'bit_size': 2048}},\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "heu0 = sf.HEU(heu_config('alice', ['bob', 'carol']), spu.spu_pb2.FM128)\n",
    "heu1 = sf.HEU(heu_config('alice', ['bob', 'dave']), spu.spu_pb2.FM128)\n",
    "heu2 = sf.HEU(heu_config('alice', ['bob', 'eric']), spu.spu_pb2.FM128)\n",
    "aggregator0 = SecureAggregator(alice, [alice, bob, carol])\n",
    "aggregator1 = SecureAggregator(alice, [alice, bob, dave])\n",
    "aggregator2 = SecureAggregator(alice, [alice, bob, eric])\n"
   ]
  },
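  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `'paillier'` schema selected in `heu_config` refers to an additively homomorphic cryptosystem: ciphertexts can be combined so that the result decrypts to the sum of the plaintexts. The following is a toy, self-contained sketch of that property in plain Python. It is not how `HEU` is implemented; the function names (`keygen`, `encrypt`, `decrypt`) and the tiny primes are assumptions chosen for illustration and offer no security.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def keygen(p, q):\n",
    "    # Public key: n = p * q (generator fixed to g = n + 1).\n",
    "    # Secret key: lam = lcm(p - 1, q - 1) and mu = lam^-1 mod n.\n",
    "    n = p * q\n",
    "    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)\n",
    "    mu = pow(lam, -1, n)  # modular inverse, needs Python 3.8+\n",
    "    return n, (lam, mu)\n",
    "\n",
    "\n",
    "def encrypt(n, m, rng):\n",
    "    # c = (n + 1)^m * r^n mod n^2, with r random and coprime to n.\n",
    "    n2 = n * n\n",
    "    while True:\n",
    "        r = rng.randrange(1, n)\n",
    "        if math.gcd(r, n) == 1:\n",
    "            break\n",
    "    return (pow(n + 1, m, n2) * pow(r, n, n2)) % n2\n",
    "\n",
    "\n",
    "def decrypt(n, sk, c):\n",
    "    # m = L(c^lam mod n^2) * mu mod n, where L(u) = (u - 1) // n.\n",
    "    lam, mu = sk\n",
    "    u = pow(c, lam, n * n)\n",
    "    return ((u - 1) // n * mu) % n\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "n, sk = keygen(293, 433)  # toy primes, for illustration only\n",
    "\n",
    "c1 = encrypt(n, 12, rng)\n",
    "c2 = encrypt(n, 30, rng)\n",
    "\n",
    "# Multiplying ciphertexts adds the underlying plaintexts.\n",
    "c_sum = (c1 * c2) % (n * n)\n",
    "print(decrypt(n, sk, c_sum))  # 42\n",
    "\n",
    "# Raising a ciphertext to a plaintext power scales the plaintext.\n",
    "c_scaled = pow(c1, 3, n * n)\n",
    "print(decrypt(n, sk, c_scaled))  # 36\n",
    "```\n",
    "\n",
    "Addition and multiplication by a plaintext scalar are the operations a linear model needs, which is why an additive scheme can be sufficient for this kind of training. The 2048-bit key size configured above is what separates a real key from this toy one."
   ]
  },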
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fit the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:MixLr epoch 0: loss = 0.22200048124132613\n",
      "INFO:root:MixLr epoch 1: loss = 0.10997288443236536\n",
      "INFO:root:MixLr epoch 2: loss = 0.08508413270494121\n",
      "INFO:root:MixLr epoch 3: loss = 0.07325763613227645\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "logging.root.setLevel(level=logging.INFO)\n",
    "\n",
    "from secretflow.ml.linear import FlLogisticRegressionMix\n",
    "\n",
    "model = FlLogisticRegressionMix()\n",
    "\n",
    "model.fit(\n",
    "    x,\n",
    "    y,\n",
    "    batch_size=64,\n",
    "    epochs=3,\n",
    "    learning_rate=0.1,\n",
    "    aggregators=[aggregator0, aggregator1, aggregator2],\n",
    "    heus=[heu0, heu1, heu2],\n",
    ")\n"
   ]
  },
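  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Secure aggregation, in general, lets a group of parties reveal a sum without revealing the individual addends. One common way to achieve this is pairwise additive masking, sketched below in plain Python. This illustrates the general idea only and is not taken from the `SecureAggregator` implementation; `MODULUS`, `pairwise_masks`, `masked_inputs` and the single seeded generator (standing in for secrets that each pair would agree on privately) are assumptions.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "MODULUS = 2**32  # all arithmetic wraps around modulo this value\n",
    "\n",
    "\n",
    "def pairwise_masks(parties, seed=0):\n",
    "    # For every unordered pair, draw one shared mask: the first party\n",
    "    # adds it and the second subtracts it, so all masks cancel in the sum.\n",
    "    rng = random.Random(seed)\n",
    "    masks = {p: 0 for p in parties}\n",
    "    for i, a in enumerate(parties):\n",
    "        for b in parties[i + 1:]:\n",
    "            m = rng.randrange(MODULUS)\n",
    "            masks[a] = (masks[a] + m) % MODULUS\n",
    "            masks[b] = (masks[b] - m) % MODULUS\n",
    "    return masks\n",
    "\n",
    "\n",
    "def masked_inputs(values, masks):\n",
    "    # What each party would upload: its private value plus its net mask.\n",
    "    return {p: (values[p] + masks[p]) % MODULUS for p in values}\n",
    "\n",
    "\n",
    "values = {'alice': 5, 'bob': 7, 'carol': 30}\n",
    "masks = pairwise_masks(list(values))\n",
    "uploads = masked_inputs(values, masks)\n",
    "\n",
    "# The aggregator only sees the masked uploads, yet their sum is exact.\n",
    "total = sum(uploads.values()) % MODULUS\n",
    "print(total)  # 42\n",
    "```\n",
    "\n",
    "Each upload on its own looks random, but because every mask is added once and subtracted once, the masks cancel when the uploads are summed."
   ]
  },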
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the trained model to make predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.9875535119708261 , acc: 0.9384885764499121\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "y_pred = np.concatenate(sf.reveal(model.predict(x)))\n",
    "\n",
    "auc = roc_auc_score(label.values, y_pred)\n",
    "acc = np.mean((y_pred > 0.5) == label.values)\n",
    "print('auc:', auc, ', acc:', acc)\n"
   ]
  },
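  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two metrics printed above can be written out by hand. The sketch below uses plain Python and four made-up scores (not output of the model): accuracy thresholds the scores at 0.5, and AUC is computed as the fraction of positive/negative pairs in which the positive sample receives the higher score, which is the quantity `roc_auc_score` reports for binary labels. The helper names `accuracy` and `auc_by_pairs` are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "def accuracy(y_true, y_score, threshold=0.5):\n",
    "    # Share of samples whose thresholded score matches the label.\n",
    "    hits = sum((s > threshold) == bool(t) for t, s in zip(y_true, y_score))\n",
    "    return hits / len(y_true)\n",
    "\n",
    "\n",
    "def auc_by_pairs(y_true, y_score):\n",
    "    # Compare every positive score with every negative score;\n",
    "    # a tie counts as half a win.\n",
    "    pos = [s for t, s in zip(y_true, y_score) if t == 1]\n",
    "    neg = [s for t, s in zip(y_true, y_score) if t == 0]\n",
    "    wins = sum(\n",
    "        1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg\n",
    "    )\n",
    "    return wins / (len(pos) * len(neg))\n",
    "\n",
    "\n",
    "y_true = [0, 0, 1, 1]\n",
    "y_score = [0.1, 0.4, 0.35, 0.8]\n",
    "\n",
    "print('auc:', auc_by_pairs(y_true, y_score), ', acc:', accuracy(y_true, y_score))\n",
    "```\n",
    "\n",
    "The pairwise form is quadratic in the number of samples, so it is only useful for building intuition; library implementations use sorting instead."
   ]
  },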
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The end\n",
    "Clean up the temporary files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "shutil.rmtree(tmp_dir, ignore_errors=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py3.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "66d1547304beaba725027c44e57cc46fc747862fe9496520910412a3187eb35f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
